from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from transformers.generation import GenerationConfig
import torch
torch.manual_seed(1234)
import json
from azfuse import File
import os
import shortuuid
from peft import AutoPeftModelForCausalLM
from tqdm import tqdm

def load_model(model_path):
    """Load a causal language model from ``model_path`` and put it in eval mode.

    Every regular file listed under ``model_path`` is first passed to
    ``File.prepare``. When the path name contains ``"lora"`` or ``"dpo"``
    (case-insensitive) the directory is treated as a PEFT adapter: its
    ``adapter_config.json`` is read to find the base model path, the base
    model's files are prepared as well, and the model is loaded through
    ``AutoPeftModelForCausalLM``. Otherwise ``AutoModelForCausalLM`` is used.

    Args:
        model_path: Directory of the model or adapter checkpoint.

    Returns:
        The loaded model, after ``.eval()`` has been called on it.
    """
    File.prepare([f for f in File.list(model_path) if File.isfile(f)])

    lowered_path = model_path.lower()
    if "lora" in lowered_path or "dpo" in lowered_path:
        # Close the config handle deterministically instead of leaking it.
        with File.open(os.path.join(model_path, "adapter_config.json")) as fp:
            cfg = json.load(fp)
        model_base = cfg["base_model_name_or_path"]
        # The adapter only references the base weights, so they must be
        # prepared too before the PEFT loader resolves them.
        File.prepare([f for f in File.list(model_base) if File.isfile(f)])
        model = AutoPeftModelForCausalLM.from_pretrained(
            model_path,  # path to the output directory
            device_map="auto",
            trust_remote_code=True,
            fp16=True,
        ).eval()
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="auto",
            trust_remote_code=True,
            fp16=True,
        ).eval()
    return model

'''
GenerationConfig from pretrained: ./models/Qwen/Qwen-VL-Chat/
{
  "chat_format": "chatml",
  "do_sample": true,
  "eos_token_id": 151643,
  "max_new_tokens": 512,
  "max_window_size": 6144,
  "pad_token_id": 151643,
  "top_k": 0,
  "top_p": 0.3,
  "transformers_version": "4.31.0"
}
'''

@torch.no_grad()
def eval(question_file, image_folder, model_path, answer_file, overwrite=False):
    """Answer every question in ``question_file`` and write a JSONL answer file.

    NOTE: the name shadows the ``eval`` builtin inside this module; it is
    kept unchanged so existing callers keep working.

    Args:
        question_file: ``.jsonl`` file (one JSON object per line) or a JSON
            file holding a list. Each record must provide ``question_id``,
            ``image`` and ``text``.
        image_folder: Directory that the ``image`` entries are joined onto.
        model_path: Checkpoint passed to ``load_model`` and to the tokenizer.
        answer_file: Output path; one JSON object is written per question.
        overwrite: When false and ``answer_file`` already exists, do nothing.

    Returns:
        None. Results are written to ``answer_file``.
    """
    if File.isfile(answer_file) and not overwrite:
        print(f"File {answer_file} already exists, skipping")
        return
    print(f"Generating answers for {question_file} using model {model_path}")

    model = load_model(model_path)
    # Note: The default behavior now has injection attack prevention off.
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # Specify hyperparameters for generation (sampling disabled).
    model.generation_config = GenerationConfig.from_pretrained(
        "./models/Qwen/Qwen-VL-Chat/", do_sample=False)

    # Read the questions inside a context manager so the handle is closed.
    with File.open(os.path.expanduser(question_file), "r") as fp:
        if question_file.endswith('.jsonl'):
            # Blank lines (e.g. a trailing newline) are skipped rather than
            # crashing json.loads.
            questions = [json.loads(row) for row in fp if row.strip()]
        else:
            questions = json.load(fp)

    answer_file = os.path.expanduser(answer_file)
    answer_dir = os.path.dirname(answer_file)
    if answer_dir:
        # os.makedirs("") raises, so only create a directory when the
        # answer path actually has one.
        os.makedirs(answer_dir, exist_ok=True)

    ans_outputs = []

    for line in tqdm(questions):
        qid = line["question_id"]
        image_file = line["image"]
        qs = line["text"]
        query = tokenizer.from_list_format([
            {'image': os.path.join(image_folder, image_file)},
            {'text': qs},
        ])
        response, _ = model.chat(tokenizer, query=query, history=None)
        ans_outputs.append({"question_id": qid,
                            "prompt": qs,
                            "text": response,
                            "answer_id": shortuuid.uuid(),
                            "model_id": model_path,
                            "metadata": {}})

    with File.open(answer_file, "w") as f:
        for d in ans_outputs:
            f.write(json.dumps(d) + "\n")


from typing import Tuple, List
from transformers import PreTrainedTokenizer


def get_stop_words_ids(chat_format, tokenizer):
    """Return the stop-word token-id sequences for the given chat format.

    Args:
        chat_format: ``"chatml"`` or ``"raw"``.
        tokenizer: Object exposing ``im_end_id`` / ``im_start_id`` (chatml)
            or ``encode`` / ``eod_id`` (raw).

    Returns:
        A list of token-id lists, one per stop sequence.

    Raises:
        NotImplementedError: For any other ``chat_format``.
    """
    if chat_format == "chatml":
        return [[tokenizer.im_end_id], [tokenizer.im_start_id]]
    if chat_format == "raw":
        human_marker_ids = tokenizer.encode("Human:")
        return [human_marker_ids, [tokenizer.eod_id]]
    raise NotImplementedError(f"Unknown chat format {chat_format!r}")


def make_context(
    tokenizer: PreTrainedTokenizer,
    query: str,
    history: List[Tuple[str, str]] = None,
    system: str = "",
    max_window_size: int = 6144,
    chat_format: str = "chatml",
):
    """Build the prompt text and its token ids for one chat request.

    For ``"raw"`` the query is encoded as-is and ``history`` is ignored.
    For ``"chatml"`` the prompt is the system message, then as many of the
    most recent history turns as keep ``system + history`` below
    ``max_window_size`` tokens, then the current user query followed by an
    opened assistant turn.

    Args:
        tokenizer: Tokenizer exposing ``encode``, ``im_start_id``,
            ``im_end_id`` and ``IMAGE_ST``.
        query: Current user message.
        history: Earlier ``(query, response)`` pairs, oldest first; a
            ``None`` response yields a user-only turn. ``None`` means no
            history.
        system: System message text.
        max_window_size: Token budget checked while adding history turns.
        chat_format: ``"chatml"`` or ``"raw"``.

    Returns:
        A ``(raw_text, context_tokens)`` tuple.

    Raises:
        NotImplementedError: For an unknown ``chat_format``.
    """
    if chat_format == "raw":
        return query, tokenizer.encode(query)
    if chat_format != "chatml":
        raise NotImplementedError(f"Unknown chat format {chat_format!r}")

    turns = history if history is not None else []

    start_tag, end_tag = "<|im_start|>", "<|im_end|>"
    start_ids = [tokenizer.im_start_id]
    end_ids = [tokenizer.im_end_id]
    newline_ids = tokenizer.encode("\n")

    def _encode_message(role, content):
        # Returns the "role\ncontent" text together with its token ids.
        role_ids = tokenizer.encode(role, allowed_special=set(tokenizer.IMAGE_ST))
        content_ids = tokenizer.encode(
            content, allowed_special=set(tokenizer.IMAGE_ST)
        )
        return f"{role}\n{content}", role_ids + newline_ids + content_ids

    system_text, system_body_ids = _encode_message("system", system)
    system_ids = start_ids + system_body_ids + end_ids

    history_text = ""
    history_ids = []

    # Walk the history newest-first so the most recent turns are kept when
    # the window budget runs out.
    for past_query, past_response in reversed(turns):
        user_text, user_body_ids = _encode_message("user", past_query)
        user_ids = start_ids + user_body_ids + end_ids

        if past_response is None:
            turn_ids = newline_ids + user_ids + newline_ids
            turn_text = f"\n{start_tag}{user_text}{end_tag}\n"
        else:
            reply_text, reply_body_ids = _encode_message(
                "assistant", past_response
            )
            reply_ids = start_ids + reply_body_ids + end_ids
            turn_ids = newline_ids + user_ids + newline_ids + reply_ids
            turn_text = (
                f"\n{start_tag}{user_text}{end_tag}\n{start_tag}{reply_text}{end_tag}"
            )

        projected_size = len(system_ids) + len(turn_ids) + len(history_ids)
        if projected_size >= max_window_size:
            break
        history_ids = turn_ids + history_ids
        history_text = turn_text + history_text

    current_user_ids = _encode_message("user", query)[1]
    assistant_ids = tokenizer.encode("assistant")
    context_tokens = (
        system_ids
        + history_ids
        + newline_ids
        + start_ids
        + current_user_ids
        + end_ids
        + newline_ids
        + start_ids
        + assistant_ids
        + newline_ids
    )
    raw_text = (
        f"{start_tag}{system_text}{end_tag}"
        + history_text
        + f"\n{start_tag}user\n{query}{end_tag}\n{start_tag}assistant\n"
    )
    return raw_text, context_tokens


def _yes_no_token_ids(tokenizer):
    """Return the token ids used for "yes", "no", "Yes" and "No".

    Each word is tokenized on its own. When the encoding of "yes" is a list
    with more than one id, position 1 is used for all four words, otherwise
    position 0 (presumably to step over a leading special token added by
    some tokenizers; TODO confirm).
    """
    yes_ids = tokenizer(["yes"]).input_ids[0]
    no_ids = tokenizer(["no"]).input_ids[0]
    cap_yes_ids = tokenizer(["Yes"]).input_ids[0]
    cap_no_ids = tokenizer(["No"]).input_ids[0]
    if not isinstance(yes_ids, list):
        return yes_ids, no_ids, cap_yes_ids, cap_no_ids
    pos = 1 if len(yes_ids) > 1 else 0
    return yes_ids[pos], no_ids[pos], cap_yes_ids[pos], cap_no_ids[pos]


@torch.no_grad()
def get_pred_prob_only_yes_or_no(question_file, pred_file, image_folder, model_path, answer_file, overwrite=False):
    """Ask the model whether each earlier prediction is correct and record
    the yes/no probabilities of its single generated token.

    For every record in ``pred_file`` a prompt of the form
    "Question: ... Answer: ... Is the answer to the question correct, yes or
    no?" is built with the image of the matching question, one new token is
    generated greedily, and the softmax of the first-step scores is read at
    the "yes"/"no" token ids (the larger of lower-case and capitalised
    variants is taken) and renormalised over those two options.

    Args:
        question_file: ``.jsonl`` or JSON list with ``question_id`` and
            ``image`` per record.
        pred_file: ``.jsonl`` or JSON list with ``question_id``, ``prompt``
            and ``text`` per record.
        image_folder: Directory that the ``image`` entries are joined onto.
        model_path: Checkpoint passed to ``load_model`` and to the tokenizer.
        answer_file: Output JSONL path.
        overwrite: When false and ``answer_file`` already exists, do nothing.

    Returns:
        None. One JSON object per prediction is written to ``answer_file``
        with keys ``question_id``, ``question``, ``prompt``, ``answer_id``,
        ``model_id``, ``metadata``, ``text``, ``average_prob``,
        ``average_log_prob``, ``text_prob``, ``yes_prob`` and ``no_prob``.

    Raises:
        KeyError: If a prediction refers to a ``question_id`` that is not in
            ``question_file``.
    """
    if File.isfile(answer_file) and not overwrite:
        print(f"File {answer_file} already exists, skipping")
        return
    print(f"Generating answers for {question_file} using model {model_path}")

    model = load_model(model_path)
    # Note: The default behavior now has injection attack prevention off.
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # Specify hyperparameters for generation: a single greedy token.
    model.generation_config = GenerationConfig.from_pretrained(
        "./models/Qwen/Qwen-VL-Chat/", max_new_tokens=1, do_sample=False)

    # Input files are read inside context managers so handles are closed.
    with File.open(os.path.expanduser(question_file), "r") as fp:
        if question_file.endswith('.jsonl'):
            questions = [json.loads(row) for row in fp if row.strip()]
        else:
            questions = json.load(fp)
    qid2q = {q["question_id"]: q for q in questions}

    with File.open(os.path.expanduser(pred_file), "r") as fp:
        if pred_file.endswith('.jsonl'):
            predictions = [json.loads(row) for row in fp if row.strip()]
        else:
            predictions = json.load(fp)

    answer_file = os.path.expanduser(answer_file)
    answer_dir = os.path.dirname(answer_file)
    if answer_dir:
        # os.makedirs("") raises, so skip it for bare file names.
        os.makedirs(answer_dir, exist_ok=True)

    # Loop-invariant lookups, hoisted out of the per-prediction loop.
    chat_format = model.generation_config.chat_format
    max_window_size = model.generation_config.max_window_size
    special_tokens = [tokenizer.bos_token_id, tokenizer.pad_token_id, tokenizer.eos_token_id]
    yes_tok_idx, no_tok_idx, cap_yes_tok_idx, cap_no_tok_idx = _yes_no_token_ids(tokenizer)

    ans_outputs = []

    for pred_line in tqdm(predictions):
        qid = pred_line["question_id"]
        line = qid2q[qid]
        image_file = line["image"]
        qs = pred_line["prompt"]
        pred_ans = pred_line["text"]
        cur_prompt = "Question: " + qs + "\nAnswer: " + pred_ans + "\nIs the answer to the question correct, yes or no?"
        query = tokenizer.from_list_format([
            {'image': os.path.join(image_folder, image_file)},
            {'text': cur_prompt},
        ])

        # Build the prompt by hand (instead of model.chat) so that the
        # per-step scores of generate() are available.
        _, context_tokens = make_context(
            tokenizer,
            query,
            history=[],
            system="You are a helpful assistant.",
            max_window_size=max_window_size,
            chat_format=chat_format,
        )
        stop_words_ids = get_stop_words_ids(chat_format, tokenizer)
        input_ids = torch.tensor([context_tokens]).to(model.device)
        out = model.generate(
            input_ids,
            stop_words_ids=stop_words_ids,
            return_dict_in_generate=True,
            output_scores=True,
            use_cache=True,
            generation_config=model.generation_config,
        )

        scores = out.scores
        # Keep only the newly generated tokens.
        sequences = out.sequences[:, input_ids.shape[1]:]
        generated_text = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0].strip()

        seq_probs, seq_log_probs = [], []
        for si, step_scores in enumerate(scores):
            # A step is skipped when the following generated token is a
            # special token; at the last position the token itself is
            # checked instead.
            try:
                if sequences[0][si + 1] in special_tokens:
                    continue
            except IndexError:
                if sequences[0][si] in special_tokens:
                    continue

            step_probs = torch.nn.functional.softmax(step_scores, dim=1)
            max_prob = torch.max(step_probs)
            seq_probs.append(max_prob.item())
            seq_log_probs.append(torch.log(max_prob).item())

        if seq_probs:
            avg_prob = sum(seq_probs) / len(seq_probs)
            avg_log_prob = sum(seq_log_probs) / len(seq_log_probs)
            # answer prob is the product of all elements in seq_probs
            answer_prob = 1.0
            for p in seq_probs:
                answer_prob *= p
        else:
            # Every step was skipped: define answer_prob too, otherwise it
            # would be unbound or carry over from the previous prediction.
            avg_prob, avg_log_prob, answer_prob = 0.0, 0.0, 0.0

        ans_w_prob = {"text": generated_text, "average_prob": avg_prob, "average_log_prob": avg_log_prob, "text_prob": answer_prob}
        lower_text = generated_text.lower()
        if lower_text not in ["yes", "no"]:
            print(f"Generated text is not yes or no: {lower_text}")

        token_probs = torch.nn.functional.softmax(scores[0], dim=1)
        yes_prob = max(token_probs[:, yes_tok_idx].item(),
                       token_probs[:, cap_yes_tok_idx].item())
        no_prob = max(token_probs[:, no_tok_idx].item(),
                      token_probs[:, cap_no_tok_idx].item())

        total = yes_prob + no_prob
        if total == 0:
            # Both probabilities underflowed to zero; report an
            # uninformative split instead of dividing by zero.
            scaled_yes_prob, scaled_no_prob = 0.5, 0.5
        else:
            scaled_yes_prob = yes_prob / total
            scaled_no_prob = no_prob / total
        ans_w_prob["yes_prob"] = scaled_yes_prob
        ans_w_prob["no_prob"] = scaled_no_prob

        ans_item = {
            "question_id": qid,
            "question": qs,
            "prompt": cur_prompt,
            "answer_id": shortuuid.uuid(),
            "model_id": model_path,
            "metadata": {}}
        ans_item.update(ans_w_prob)
        ans_outputs.append(ans_item)

    with File.open(answer_file, "w") as f:
        for d in ans_outputs:
            f.write(json.dumps(d) + "\n")





def split_list(lst, n):
    """Split a list into at most n (roughly) equal-sized chunks.

    The chunk size is ``ceil(len(lst) / n)``, so the last chunk may be
    shorter and fewer than ``n`` chunks can be returned (for example three
    items split four ways give three chunks). An empty list gives ``[]``.

    Args:
        lst: Sequence to split; must support ``len`` and slicing.
        n: Requested number of chunks.

    Returns:
        A list of consecutive slices of ``lst``.
    """
    import math

    chunk_size = math.ceil(len(lst) / n)  # ceiling division
    if chunk_size <= 0:
        # Empty input: a zero step would make range() raise ValueError.
        return []
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]


def get_chunk(lst, n, k):
    """Return chunk ``k`` of ``lst`` when it is split into ``n`` chunks.

    ``split_list`` may produce fewer than ``n`` chunks when ``len(lst)`` is
    small relative to ``n``. An index that is valid for ``n`` workers but
    lies past the produced chunks then yields an empty list instead of
    raising ``IndexError``, so that worker simply has nothing to do.

    Args:
        lst: Sequence to split.
        n: Total number of chunks requested.
        k: Index of the chunk to return.

    Returns:
        The ``k``-th chunk, or ``[]`` if that chunk does not exist but
        ``k < n``.

    Raises:
        IndexError: If ``k`` is outside the range accepted for ``n`` chunks.
    """
    chunks = split_list(lst, n)
    if len(chunks) <= k < n:
        return []
    return chunks[k]


@torch.no_grad()
def eval_w_chunk(question_file, image_folder, model_path, answer_file, num_chunks, chunk_idx, overwrite=False):
    """Answer one chunk of the questions in ``question_file``.

    The questions are loaded, split with ``get_chunk(questions, num_chunks,
    chunk_idx)`` and only that chunk is answered, so several workers can
    share one question file.

    Args:
        question_file: ``.jsonl`` file (one JSON object per line) or a JSON
            file holding a list. Each record must provide ``question_id``,
            ``image`` and ``text``.
        image_folder: Directory that the ``image`` entries are joined onto.
        model_path: Checkpoint passed to ``load_model`` and to the tokenizer.
        answer_file: Output path; one JSON object is written per question.
        num_chunks: Total number of chunks the questions are split into.
        chunk_idx: Index of the chunk handled by this call.
        overwrite: When false and ``answer_file`` already exists, do nothing.

    Returns:
        None. Results are written to ``answer_file``.
    """
    if File.isfile(answer_file) and not overwrite:
        print(f"File {answer_file} already exists, skipping")
        return
    print(f"Generating answers for {question_file} using model {model_path}")

    model = load_model(model_path)
    # Note: The default behavior now has injection attack prevention off.
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # Specify hyperparameters for generation (sampling disabled).
    model.generation_config = GenerationConfig.from_pretrained(
        "./models/Qwen/Qwen-VL-Chat/", do_sample=False)

    # Read the questions inside a context manager so the handle is closed.
    with File.open(os.path.expanduser(question_file), "r") as fp:
        if question_file.endswith('.jsonl'):
            # Blank lines (e.g. a trailing newline) are skipped rather than
            # crashing json.loads.
            questions = [json.loads(row) for row in fp if row.strip()]
        else:
            questions = json.load(fp)

    questions = get_chunk(questions, num_chunks, chunk_idx)

    answer_file = os.path.expanduser(answer_file)
    answer_dir = os.path.dirname(answer_file)
    if answer_dir:
        # os.makedirs("") raises, so only create a directory when the
        # answer path actually has one.
        os.makedirs(answer_dir, exist_ok=True)

    ans_outputs = []

    for line in tqdm(questions):
        qid = line["question_id"]
        image_file = line["image"]
        qs = line["text"]
        query = tokenizer.from_list_format([
            {'image': os.path.join(image_folder, image_file)},
            {'text': qs},
        ])
        response, _ = model.chat(tokenizer, query=query, history=None)
        ans_outputs.append({"question_id": qid,
                            "prompt": qs,
                            "text": response,
                            "answer_id": shortuuid.uuid(),
                            "model_id": model_path,
                            "metadata": {}})

    with File.open(answer_file, "w") as f:
        for d in ans_outputs:
            f.write(json.dumps(d) + "\n")


def merge_files(input_folder, num_chunks, output_file, poll_interval=10, timeout=None):
    """Concatenate the per-chunk answer files into ``output_file``.

    The chunk files are expected at
    ``{input_folder}/{num_chunks}_{i}.jsonl`` for ``i`` in
    ``range(num_chunks)``. Each missing file is polled until it exists, so
    the merge can be started while chunk workers are still running. The
    files are then copied line by line, in chunk order.

    Args:
        input_folder: Directory holding the chunk files.
        num_chunks: Number of chunk files to merge.
        output_file: Path of the merged file.
        poll_interval: Seconds to sleep between existence checks.
        timeout: Maximum total seconds to wait for missing files; ``None``
            waits indefinitely.

    Raises:
        TimeoutError: If ``timeout`` elapses while a chunk file is missing.
    """
    import time

    file_list = [f"{input_folder}/{num_chunks}_{i}.jsonl" for i in range(num_chunks)]

    started = time.monotonic()
    for chunk_file in file_list:
        # Keep re-checking: a single sleep does not guarantee the file has
        # appeared by the time it is opened below.
        while not File.isfile(chunk_file):
            if timeout is not None and time.monotonic() - started > timeout:
                raise TimeoutError(f"Timed out waiting for {chunk_file}")
            print(f"Waiting for {chunk_file}")
            time.sleep(poll_interval)

    File.prepare(file_list)
    with File.open(output_file, "w") as out:
        for chunk_file in file_list:
            with File.open(chunk_file, "r") as f:
                for row in f:
                    out.write(row)

if __name__ == "__main__":
    # Command-line entry point: python-fire dispatches to the functions
    # defined in this module.
    import fire

    fire.Fire()
